"""Loading datasets and evaluators."""

from collections.abc import Sequence
from typing import Any

from langchain_core.language_models import BaseLanguageModel

from langchain_classic.chains.base import Chain
from langchain_classic.evaluation.agents.trajectory_eval_chain import (
    TrajectoryEvalChain,
)
from langchain_classic.evaluation.comparison import PairwiseStringEvalChain
from langchain_classic.evaluation.comparison.eval_chain import (
    LabeledPairwiseStringEvalChain,
)
from langchain_classic.evaluation.criteria.eval_chain import (
    CriteriaEvalChain,
    LabeledCriteriaEvalChain,
)
from langchain_classic.evaluation.embedding_distance.base import (
    EmbeddingDistanceEvalChain,
    PairwiseEmbeddingDistanceEvalChain,
)
from langchain_classic.evaluation.exact_match.base import ExactMatchStringEvaluator
from langchain_classic.evaluation.parsing.base import (
    JsonEqualityEvaluator,
    JsonValidityEvaluator,
)
from langchain_classic.evaluation.parsing.json_distance import JsonEditDistanceEvaluator
from langchain_classic.evaluation.parsing.json_schema import JsonSchemaEvaluator
from langchain_classic.evaluation.qa import (
    ContextQAEvalChain,
    CotQAEvalChain,
    QAEvalChain,
)
from langchain_classic.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain_classic.evaluation.schema import (
    EvaluatorType,
    LLMEvalChain,
    StringEvaluator,
)
from langchain_classic.evaluation.scoring.eval_chain import (
    LabeledScoreStringEvalChain,
    ScoreStringEvalChain,
)
from langchain_classic.evaluation.string_distance.base import (
    PairwiseStringDistanceEvalChain,
    StringDistanceEvalChain,
)


def load_dataset(uri: str) -> list[dict]:
    """Fetch a dataset from [LangChainDatasets on HuggingFace](https://huggingface.co/LangChainDatasets).

    The optional `datasets` package is imported lazily, so it is only required
    when this function is actually called.

    Args:
        uri: Name of the dataset inside the `LangChainDatasets` namespace.

    Returns:
        The rows of the dataset's `train` split, collected into a list.

    Raises:
        ImportError: If the `datasets` package cannot be imported.

    **Prerequisites**

    ```bash
    pip install datasets
    ```

    Examples:
    --------
    ```python
    from langchain_classic.evaluation import load_dataset

    ds = load_dataset("llm-math")
    ```
    """
    try:
        # Aliased so the HuggingFace loader does not shadow this function's name.
        from datasets import load_dataset as _hf_load_dataset
    except ImportError as exc:
        install_hint = (
            "load_dataset requires the `datasets` package."
            " Please install with `pip install datasets`"
        )
        raise ImportError(install_hint) from exc

    train_split = _hf_load_dataset(f"LangChainDatasets/{uri}")["train"]
    return list(train_split)


# Registry mapping each `EvaluatorType` to the class that `load_evaluator`
# instantiates for it.
#
# NOTE: insertion order is user-visible. `load_evaluator` lists these keys, in
# order, in its "Unknown evaluator type" error message, so append new entries
# rather than reordering existing ones.
_EVALUATOR_MAP: dict[
    EvaluatorType,
    type[LLMEvalChain] | type[Chain] | type[StringEvaluator],
] = {
    EvaluatorType.QA: QAEvalChain,
    EvaluatorType.COT_QA: CotQAEvalChain,
    EvaluatorType.CONTEXT_QA: ContextQAEvalChain,
    EvaluatorType.PAIRWISE_STRING: PairwiseStringEvalChain,
    EvaluatorType.SCORE_STRING: ScoreStringEvalChain,
    EvaluatorType.LABELED_PAIRWISE_STRING: LabeledPairwiseStringEvalChain,
    EvaluatorType.LABELED_SCORE_STRING: LabeledScoreStringEvalChain,
    EvaluatorType.AGENT_TRAJECTORY: TrajectoryEvalChain,
    EvaluatorType.CRITERIA: CriteriaEvalChain,
    EvaluatorType.LABELED_CRITERIA: LabeledCriteriaEvalChain,
    EvaluatorType.STRING_DISTANCE: StringDistanceEvalChain,
    EvaluatorType.PAIRWISE_STRING_DISTANCE: PairwiseStringDistanceEvalChain,
    EvaluatorType.EMBEDDING_DISTANCE: EmbeddingDistanceEvalChain,
    EvaluatorType.PAIRWISE_EMBEDDING_DISTANCE: PairwiseEmbeddingDistanceEvalChain,
    EvaluatorType.JSON_VALIDITY: JsonValidityEvaluator,
    EvaluatorType.JSON_EQUALITY: JsonEqualityEvaluator,
    EvaluatorType.JSON_EDIT_DISTANCE: JsonEditDistanceEvaluator,
    EvaluatorType.JSON_SCHEMA_VALIDATION: JsonSchemaEvaluator,
    EvaluatorType.REGEX_MATCH: RegexMatchStringEvaluator,
    EvaluatorType.EXACT_MATCH: ExactMatchStringEvaluator,
}


def load_evaluator(
    evaluator: EvaluatorType,
    *,
    llm: BaseLanguageModel | None = None,
    **kwargs: Any,
) -> Chain | StringEvaluator:
    """Load the evaluator registered for the requested evaluator type.

    Evaluator classes that subclass `LLMEvalChain` are built with
    `from_llm(llm=..., **kwargs)`; every other evaluator class is constructed
    directly from `**kwargs` and `llm` is ignored.

    Parameters
    ----------
    evaluator : EvaluatorType
        The type of evaluator to load.
    llm : BaseLanguageModel, optional
        The language model to use for evaluation. When omitted and the
        evaluator needs one, a default
        `ChatOpenAI(model="gpt-4", seed=42, temperature=0)` is created. The
        OpenAI integration is only imported in that case, so passing an
        explicit model works without any OpenAI package installed.
    **kwargs : Any
        Additional keyword arguments to pass to the evaluator.

    Returns:
    -------
    Chain | StringEvaluator
        The loaded evaluator.

    Raises:
    ------
    ValueError
        If `evaluator` is not a known evaluator type, or if an LLM-based
        evaluator was requested without `llm` and the default model could not
        be created (missing integration package, bad credentials, ...).

    Examples:
    --------
    >>> from langchain_classic.evaluation import load_evaluator, EvaluatorType
    >>> evaluator = load_evaluator(EvaluatorType.QA)
    """
    if evaluator not in _EVALUATOR_MAP:
        msg = (
            f"Unknown evaluator type: {evaluator}"
            f"\nValid types are: {list(_EVALUATOR_MAP.keys())}"
        )
        raise ValueError(msg)
    evaluator_cls = _EVALUATOR_MAP[evaluator]
    if not issubclass(evaluator_cls, LLMEvalChain):
        return evaluator_cls(**kwargs)

    # Only resolve the default model when the caller did not supply one;
    # otherwise a missing OpenAI integration would wrongly fail the call.
    if llm is None:
        try:
            try:
                from langchain_openai import ChatOpenAI
            except ImportError:
                try:
                    from langchain_community.chat_models.openai import (  # type: ignore[no-redef]
                        ChatOpenAI,
                    )
                except ImportError as e:
                    msg = (
                        "Could not import langchain_openai or fallback onto "
                        "langchain_community. Please install langchain_openai "
                        "or specify a language model explicitly. "
                        "It's recommended to install langchain_openai AND "
                        "specify a language model explicitly."
                    )
                    raise ImportError(msg) from e

            llm = ChatOpenAI(model="gpt-4", seed=42, temperature=0)
        except Exception as e:
            # Import and construction failures are both reported as ValueError,
            # which is the exception type callers of this function already see.
            msg = (
                f"Evaluation with the {evaluator_cls} requires a "
                "language model to function."
                " Failed to create the default 'gpt-4' model."
                " Please manually provide an evaluation LLM"
                " or check your openai credentials."
            )
            raise ValueError(msg) from e
    return evaluator_cls.from_llm(llm=llm, **kwargs)


def load_evaluators(
    evaluators: Sequence[EvaluatorType],
    *,
    llm: BaseLanguageModel | None = None,
    config: dict | None = None,
    **kwargs: Any,
) -> list[Chain | StringEvaluator]:
    """Load one evaluator per requested evaluator type.

    Each entry of `evaluators` is passed to `load_evaluator` together with
    `llm` and the merged keyword arguments for that type.

    Parameters
    ----------
    evaluators : Sequence[EvaluatorType]
        The evaluator types to load, in order.
    llm : BaseLanguageModel, optional
        The language model forwarded to `load_evaluator` for every entry.
    config : dict, optional
        A dictionary mapping evaluator types to additional keyword arguments
        for that type, by default None. On a key clash these take precedence
        over `**kwargs`.
    **kwargs : Any
        Additional keyword arguments to pass to all evaluators.

    Returns:
    -------
    list[Chain | StringEvaluator]
        The loaded evaluators, in the same order as `evaluators`.

    Examples:
    --------
    >>> from langchain_classic.evaluation import load_evaluators, EvaluatorType
    >>> evaluators = [EvaluatorType.QA, EvaluatorType.CRITERIA]
    >>> loaded_evaluators = load_evaluators(evaluators, criteria="helpfulness")
    """
    # Normalise a missing/empty config once so the lookup below is uniform.
    per_type_kwargs = config or {}
    return [
        load_evaluator(
            evaluator_type,
            llm=llm,
            # Later unpacking wins: per-type settings override shared kwargs.
            **{**kwargs, **per_type_kwargs.get(evaluator_type, {})},
        )
        for evaluator_type in evaluators
    ]
